{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Getting Started with `mistral-inference`\n",
    "\n",
    "This notebook walks you through running Mistral models locally. It covers:\n",
    "- How to chat with Mistral 7B Instruct\n",
    "- How to run Mistral 7B Instruct with function calling capabilities\n",
    "\n",
    "We recommend running this notebook on a GPU such as an A100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G6tXvIsQenpI"
   },
   "outputs": [],
   "source": [
    "!pip install mistral-inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Mistral 7B Instruct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "4ytmRt0WQeMW"
   },
   "outputs": [],
   "source": [
    "!wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eRZg_8wvs5A6"
   },
   "outputs": [],
   "source": [
    "!DIR=$HOME/mistral_7b_instruct_v3 && mkdir -p $DIR && tar -xf mistral-7B-Instruct-v0.3.tar -C $DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7CN8gShDf65M"
   },
   "outputs": [],
   "source": [
    "!ls $HOME/mistral_7b_instruct_v3"
   ]
  },
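  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm the extraction worked before loading anything, a quick file check can help. The sketch below is a minimal example and not part of `mistral-inference`: `missing_files` is a hypothetical helper, and apart from `tokenizer.model.v3` (which this notebook loads in the chat example) the file names in `EXPECTED_FILES` are assumptions about the archive layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Assumed archive contents: only tokenizer.model.v3 is confirmed by this notebook;\n",
    "# the other two names are assumptions, so adjust the list if your archive differs.\n",
    "EXPECTED_FILES = [\"params.json\", \"consolidated.safetensors\", \"tokenizer.model.v3\"]\n",
    "\n",
    "\n",
    "def missing_files(folder, expected=EXPECTED_FILES):\n",
    "    \"\"\"Return the expected file names that are absent from `folder`.\"\"\"\n",
    "    folder = Path(folder)\n",
    "    return [name for name in expected if not (folder / name).is_file()]\n",
    "\n",
    "\n",
    "model_dir = Path.home() / \"mistral_7b_instruct_v3\"\n",
    "missing = missing_files(model_dir)\n",
    "if missing:\n",
    "    print(f\"Missing from {model_dir}: {missing}\")\n",
    "else:\n",
    "    print(\"All expected files found.\")"
   ]
  },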
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chat with the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "\n",
    "from mistral_inference.transformer import Transformer\n",
    "from mistral_inference.generate import generate\n",
    "\n",
    "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
    "from mistral_common.protocol.instruct.messages import UserMessage\n",
    "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
    "\n",
    "# load tokenizer\n",
    "mistral_tokenizer = MistralTokenizer.from_file(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3/tokenizer.model.v3\")\n",
    "# chat completion request\n",
    "completion_request = ChatCompletionRequest(messages=[UserMessage(content=\"Explain Machine Learning to me in a nutshell.\")])\n",
    "# encode message\n",
    "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens\n",
    "# load model\n",
    "model = Transformer.from_folder(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3\")\n",
    "# generate results\n",
    "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)\n",
    "# decode generated tokens\n",
    "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ce4woS3LkgZ9"
   },
   "source": [
    "## Function calling\n",
    "\n",
    "Mistral 7B Instruct v3 also supports function calling!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TKfPiEwNk1kh"
   },
   "source": [
    "Let's start by creating a function calling example: a request that describes a `get_current_weather` tool alongside the user's message."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0PJdwvDEk3dl"
   },
   "outputs": [],
   "source": [
    "from mistral_common.protocol.instruct.messages import UserMessage\n",
    "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
    "from mistral_common.protocol.instruct.tool_calls import Function, Tool\n",
    "\n",
    "# Describe the tool with a JSON Schema so the model knows its name and arguments\n",
    "completion_request = ChatCompletionRequest(\n",
    "    tools=[\n",
    "        Tool(\n",
    "            function=Function(\n",
    "                name=\"get_current_weather\",\n",
    "                description=\"Get the current weather\",\n",
    "                parameters={\n",
    "                    \"type\": \"object\",\n",
    "                    \"properties\": {\n",
    "                        \"location\": {\n",
    "                            \"type\": \"string\",\n",
    "                            \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
    "                        },\n",
    "                        \"format\": {\n",
    "                            \"type\": \"string\",\n",
    "                            \"enum\": [\"celsius\", \"fahrenheit\"],\n",
    "                            \"description\": \"The temperature unit to use. Infer this from the user's location.\",\n",
    "                        },\n",
    "                    },\n",
    "                    \"required\": [\"location\", \"format\"],\n",
    "                },\n",
    "            )\n",
    "        )\n",
    "    ],\n",
    "    messages=[\n",
    "        UserMessage(content=\"What's the weather like today in Paris?\"),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bG6ZeZUylpBW"
   },
   "source": [
    "Since we have already loaded the tokenizer and the model in the example above, we will skip these steps here.\n",
    "\n",
    "Now we can encode the request with the `MistralTokenizer` instance we created earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ii8q-JNClwiq"
   },
   "outputs": [],
   "source": [
    "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
    "\n",
    "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens"
   ]
  },
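  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NrueDujkmJT4"
   },
   "source": [
    "Then run `generate` to get a response. Don't forget to pass the EOS id, so that generation can stop at the end-of-sequence token instead of always running to `max_tokens`!"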
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GWJYO43rl0V8"
   },
   "outputs": [],
   "source": [
    "from mistral_inference.generate import generate\n",
    "\n",
    "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v7baJ1msmPMv"
   },
   "source": [
    "Finally, we can decode the generated tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RKhryfBWmHon"
   },
   "outputs": [],
   "source": [
    "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n",
    "result"
   ]
  }
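  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With function calling, the decoded result is expected to describe which tool to call rather than give a plain-text answer. The sketch below shows one way to act on it, assuming the decoded string is a JSON list of objects with `name` and `arguments` keys. The `get_current_weather` body, `TOOL_REGISTRY`, `run_tool_calls`, and the `sample` string are all hypothetical stand-ins, and `sample` is not real model output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def get_current_weather(location, format):\n",
    "    \"\"\"Hypothetical stand-in for a real weather lookup; returns canned data.\"\"\"\n",
    "    return {\"location\": location, \"temperature\": 22, \"unit\": format}\n",
    "\n",
    "\n",
    "# Hypothetical registry mapping tool names to local Python callables\n",
    "TOOL_REGISTRY = {\"get_current_weather\": get_current_weather}\n",
    "\n",
    "\n",
    "def run_tool_calls(decoded, registry=TOOL_REGISTRY):\n",
    "    \"\"\"Parse a JSON list of tool calls and run each registered function.\"\"\"\n",
    "    outputs = []\n",
    "    for call in json.loads(decoded):\n",
    "        func = registry.get(call[\"name\"])\n",
    "        if func is None:\n",
    "            raise KeyError(f\"No local function registered for {call['name']!r}\")\n",
    "        outputs.append(func(**call[\"arguments\"]))\n",
    "    return outputs\n",
    "\n",
    "\n",
    "# Illustrative string in the assumed format, not actual model output\n",
    "sample = '[{\"name\": \"get_current_weather\", \"arguments\": {\"location\": \"Paris\", \"format\": \"celsius\"}}]'\n",
    "print(run_tool_calls(sample))"
   ]
  }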
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "L4",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
